//
// Created by pulsarv on 19-12-8.
//

#include <vehicle_attributes_classifier.h>
#include <src/include/vehicle_attributes_classifier.h>


namespace MobileSearch {
    /**
     * Reads the vehicle-attributes IR model, configures its single input and two outputs,
     * and loads it onto `deviceName`.
     *
     * Input: precision U8. With `auto_resize` the layout is NHWC and bilinear resize
     * pre-processing is enabled; otherwise the layout is NCHW.
     * Outputs: both are switched to FP32. The first entry of the outputs map is used as the
     * color output and the second as the type output.
     *
     * NOTE(review): `xmlPath` and `autoResize` are not read by this constructor. The model path
     * comes from `m_va_path` (weights are expected next to it with a ".bin" extension), and the
     * resize mode comes from `auto_resize`. Confirm with callers which pair is intended.
     *
     * @param ie            Inference Engine core; a reference to it is stored in `ie_`
     * @param deviceName    target device passed to LoadNetwork
     * @param pluginConfig  plugin configuration passed to LoadNetwork
     * @param m_va_path     path to the model .xml file
     * @param auto_resize   select NHWC + bilinear-resize input instead of NCHW
     * @throws std::logic_error if the topology does not have exactly one input and two outputs
     */
    VehicleAttributesClassifier::VehicleAttributesClassifier(InferenceEngine::Core &ie,
                                                             const std::string &deviceName,
                                                             const std::string &xmlPath, const bool autoResize,
                                                             const std::map<std::string, std::string> &pluginConfig,
                                                             const std::string& m_va_path, bool auto_resize
    ) : ie_(ie) {
        InferenceEngine::CNNNetReader reader;
        reader.ReadNetwork(m_va_path);
        reader.ReadWeights(fileNameNoExt(m_va_path) + ".bin");

        // ---- input configuration ----
        InferenceEngine::InputsDataMap inputs(reader.getNetwork().getInputsInfo());
        if (inputs.size() != 1) {
            throw std::logic_error("Vehicle Attribs topology should have only one input");
        }
        const auto inputEntry = inputs.begin();
        InferenceEngine::InputInfo::Ptr &input = inputEntry->second;
        input->setPrecision(InferenceEngine::Precision::U8);
        if (!auto_resize) {
            input->setLayout(InferenceEngine::Layout::NCHW);
        } else {
            // Resizing is delegated to the pre-processing stage; setImage() detects this mode via NHWC.
            input->getPreProcess().setResizeAlgorithm(InferenceEngine::ResizeAlgorithm::RESIZE_BILINEAR);
            input->setLayout(InferenceEngine::Layout::NHWC);
        }
        attributesInputName = inputEntry->first;

        // ---- output configuration ----
        InferenceEngine::OutputsDataMap outputs(reader.getNetwork().getOutputsInfo());
        if (outputs.size() != 2) {
            throw std::logic_error("Vehicle Attribs Network expects networks having two outputs");
        }
        // Which output is "color" vs "type" is decided purely by the map's iteration order
        // (presumably sorted by output name - confirm against the model).
        auto colorOutput = outputs.begin();
        auto typeOutput = colorOutput;
        ++typeOutput;

        colorOutput->second->setPrecision(InferenceEngine::Precision::FP32);
        outputNameForColor = colorOutput->second->getName();
        typeOutput->second->setPrecision(InferenceEngine::Precision::FP32);
        outputNameForType = typeOutput->second->getName();

        net = ie_.LoadNetwork(reader.getNetwork(), deviceName, pluginConfig);
    }

    /**
     * Creates a new inference request on the network loaded by the constructor.
     *
     * Usage implied by this class's API: fill the request with setImage(), run inference
     * on it, then decode it with getResults().
     *
     * @return a fresh request bound to `net`
     */
    InferenceEngine::InferRequest VehicleAttributesClassifier::createInferRequest() {
        InferenceEngine::InferRequest request = net.CreateInferRequest();
        return request;
    }

    /**
     * Fills the classifier input of `inferRequest` with the vehicle region of `img`.
     *
     * The rectangle is first intersected with the image bounds, so detector boxes that
     * extend past the frame edge are cropped instead of producing an out-of-range ROI.
     *
     * The path is chosen from the input blob layout set up by the constructor:
     *  - NHWC (auto-resize mode): the whole frame is wrapped with wrapMat2Blob, and an ROI blob
     *    over the vehicle region is set on the request. Resizing is left to the input's
     *    pre-processing.
     *  - otherwise: the vehicle crop is written into the existing input blob with matU8ToBlob.
     *
     * NOTE(review): in the NHWC path the ROI blob appears to refer to `img`'s pixel memory
     * rather than a copy (wrapMat2Blob is defined elsewhere). If so, `img` must stay alive until
     * inference on `inferRequest` has completed. Confirm.
     *
     * @param inferRequest request obtained from createInferRequest()
     * @param img          source frame; expected to be 8-bit to match the U8 input precision
     * @param vehicleRect  vehicle bounding box in `img` pixel coordinates
     * @throws std::invalid_argument if `vehicleRect` does not overlap `img` at all
     */
    void VehicleAttributesClassifier::setImage(InferenceEngine::InferRequest &inferRequest, const cv::Mat &img,
                                               const cv::Rect& vehicleRect) {
        // cv::Mat::operator() rejects rectangles outside the matrix, and an ROI past the frame
        // edge would describe pixels that do not exist, so clamp to the frame first.
        const cv::Rect roi = vehicleRect & cv::Rect(0, 0, img.cols, img.rows);
        if (roi.width <= 0 || roi.height <= 0) {
            throw std::invalid_argument("Vehicle rectangle does not intersect the image");
        }

        InferenceEngine::Blob::Ptr inputBlob = inferRequest.GetBlob(attributesInputName);
        if (InferenceEngine::Layout::NHWC == inputBlob->getTensorDesc().getLayout()) {  // autoResize is set
            InferenceEngine::ROI cropRoi{0, static_cast<size_t>(roi.x), static_cast<size_t>(roi.y),
                                         static_cast<size_t>(roi.width),
                                         static_cast<size_t>(roi.height)};
            InferenceEngine::Blob::Ptr frameBlob = wrapMat2Blob(img);
            inferRequest.SetBlob(attributesInputName, make_shared_blob(frameBlob, cropRoi));
        } else {
            const cv::Mat vehicleImage = img(roi);
            matU8ToBlob<uint8_t>(vehicleImage, inputBlob);
        }
    }

    /**
     * Decodes the vehicle color and type from a request whose inference has finished.
     *
     * Both outputs are read as FP32 scores, the precision configured in the constructor.
     * For each output, the label with the highest score is chosen.
     *
     * @param inferRequest request on which inference has already completed
     * @return {color, type}, e.g. {"white", "car"}
     * @throws std::logic_error if an output blob holds fewer scores than there are labels
     */
    std::pair<std::string, std::string>
    VehicleAttributesClassifier::getResults(InferenceEngine::InferRequest &inferRequest) {
        static const std::string colors[] = {
                "white", "gray", "yellow", "red", "green", "blue", "black"
        };
        static const std::string types[] = {
                "car", "van", "truck", "bus"
        };
        // Search ranges come from the label tables so the tables and the ranges cannot drift apart.
        const size_t colorCount = sizeof(colors) / sizeof(colors[0]);
        const size_t typeCount = sizeof(types) / sizeof(types[0]);

        const InferenceEngine::Blob::Ptr colorBlob = inferRequest.GetBlob(outputNameForColor);
        const InferenceEngine::Blob::Ptr typeBlob = inferRequest.GetBlob(outputNameForType);
        // Guard against models whose heads are smaller than the label tables;
        // max_element would otherwise read past the end of the output buffer.
        if (colorBlob->size() < colorCount || typeBlob->size() < typeCount) {
            throw std::logic_error("Vehicle Attribs Network output is smaller than the number of known labels");
        }

        // Hold the locked memory objects while their raw pointers are in use.
        auto colorMemory = colorBlob->buffer();
        auto typeMemory = typeBlob->buffer();
        const float *colorScores = colorMemory.as<float *>();
        const float *typeScores = typeMemory.as<float *>();

        const auto colorId = std::max_element(colorScores, colorScores + colorCount) - colorScores;
        const auto typeId = std::max_element(typeScores, typeScores + typeCount) - typeScores;
        return std::pair<std::string, std::string>(colors[colorId], types[typeId]);
    }

}
